import os
import pickle
from typing import Any, Optional

import econml
import networkx as nx
import numpy as np
import pandas as pd
from memory_profiler import memory_usage
from omegaconf import DictConfig, OmegaConf
import argparse
from castle.metrics import MetricsDAG

from data.load_data import load_data, load_ground_truth
from causal_discovery import learn_structure
from dowhy_experiment import run_dowhy_experiment
from custom_model_experiment import run_custom_model
from utils.graph import save_causal_graph
from utils.data import get_ground_truth_distributions, preprocess_data
from utils.metrics import seed_metrics
from utils.utils import parse_dict
import yaml


def run_seed(
    cfg: DictConfig,
    save_path: str,
    seed: int,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Run one seed of a causal experiment and write its artefacts to ``<save_path>/seed_<seed>``.

    The function optionally learns a causal graph (unless ``cfg.causal_discovery`` is
    ``"ground_truth"``), runs the configured model under memory profiling, evaluates every
    treatment against the ground truth and saves per-treatment results, runtimes and the
    memory-usage timeline as CSV files.

    Parameters:
        cfg (DictConfig): The configuration object containing all the necessary data.
        save_path (str): The directory in which the ``seed_<seed>`` folder is created.
        seed (int): The seed value; forwarded to structure learning and to the model, and
            embedded in the output filenames.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: The results frame and the runtime frame of the
        last treatment processed (frames of earlier treatments are only written to disk).
        Two empty frames are returned when ``cfg.only_causal_discovery`` is set or when
        ``nx.has_path`` finds no path from the first treatment variable to the outcome.
        Either frame is empty if the corresponding list of treatments / runtimes is empty.
    """

    # Create seed path (exist_ok removes the need for a separate existence check)
    seed_path = os.path.join(save_path, f"seed_{seed}")
    os.makedirs(seed_path, exist_ok=True)

    (
        outcome,
        evidence,
        treatment,
        ground_truth_graph,
        data,
        int_table,
        ground_truth,
        gt_distributions,
    ) = setup_seed(cfg)

    # The ground truth is reloaded for evaluation further down; drop this reference
    # before the memory-profiled model run.
    del ground_truth

    uses_learned_graph = cfg.causal_discovery != "ground_truth"

    if uses_learned_graph:
        learned_causal_graph, discovery_time = learn_structure(
            cfg, data, int_table, save_path=seed_path, seed=seed
        )

        # NOTE(review): each adjacency matrix uses its own graph's node order; confirm
        # that the learned graph lists its nodes in the same order as the ground truth.
        met = MetricsDAG(
            nx.adjacency_matrix(learned_causal_graph).todense(),
            nx.adjacency_matrix(ground_truth_graph).todense(),
        )

        # Convert to dataframe and save it
        cd_metrics = met.metrics.copy()
        cd_metrics.update(
            {
                "Seed": seed,
                "Discovery Time": discovery_time,
                "Causal Discovery": cfg.causal_discovery,
                "Dataset Size": cfg.dataset_size,
            }
        )
        cd_metrics_df = pd.DataFrame([cd_metrics])
        cd_metrics_df.to_csv(
            os.path.join(seed_path, f"cd_metrics_RS{seed}.csv"), index=False
        )

        causal_graph = learned_causal_graph
        save_causal_graph(learned_causal_graph, seed_path, "learned_causal_graph_dag")
    else:
        causal_graph = ground_truth_graph

    only_causal_discovery = cfg.only_causal_discovery
    if only_causal_discovery:
        print("Only causal discovery is enabled. Skipping treatment estimation.")

    # Check if there is a path from the first treatment variable to the outcome.
    # The missing path is reported in both modes; either condition ends the seed early.
    treatment_var = treatment[0]["treatment_var"]
    path_exists = nx.has_path(causal_graph, treatment_var, outcome)
    if not path_exists:
        print(
            f"No path between treatment variable {treatment[0]['treatment_var']} and target variable {outcome}"
        )
    if only_causal_discovery or not path_exists:
        return pd.DataFrame(), pd.DataFrame()

    args = (
        causal_graph,
        cfg.experiment,
        treatment,
        outcome,
        evidence,
        seed,
        data,
        int_table,
        seed_path,
    )
    # Run the model through memory_profiler so that memory is sampled every delta_t
    # seconds; retval=True additionally hands back run_model's return value.
    memory_usage_args = (run_model, args)
    delta_t = 0.1
    mem_usage, pkd_results = memory_usage(
        memory_usage_args, retval=True, interval=delta_t
    )
    results_list, runtime_list = pkd_results

    # Attach the same memory summary to the runtime record of every treatment
    avg_memory = np.average(mem_usage)
    max_memory = np.max(mem_usage)
    for runtime in runtime_list:
        runtime.update({"Avg. Memory Usage": avg_memory})
        runtime.update({"Max. Memory Usage": max_memory})

    ### EVAL SEED METRICS ###
    # Load ground truth again (Needed for computing MMD)
    (
        outcome,
        evidence,
        treatment_list,
        ground_truth_graph,
        data,
        int_table,
        ground_truth_list,
        gt_distribution_list,
    ) = setup_seed(cfg)

    # Bound up front so the return below is valid even without any treatment
    results_df = pd.DataFrame()
    for k, (treatment_k, ground_truth_k, gt_distributions_k) in enumerate(
        zip(treatment_list, ground_truth_list, gt_distribution_list)
    ):
        treatment_res = results_list[k].copy()

        # Compute metrics
        metrics = seed_metrics(
            treatment_res,
            gt_distributions_k,
            treatment_k,
            outcome,
            evidence,
            ground_truth_k,
        )
        treatment_res.update(metrics)

        if uses_learned_graph:
            treatment_res.update(met.metrics.copy())
            treatment_res.update({"Causal Discovery": cfg.causal_discovery})

        ### SAVE RESULTS ###

        # Save Results in a pandas dataframe
        res_dict = {
            "Seed": seed,
            "ATE": treatment_res["ATE"],
            "CATE": treatment_res["CATE"],
        }
        res_dict.update(metrics)
        results_df = pd.DataFrame(res_dict, index=[0])
        results_df.to_csv(
            os.path.join(seed_path, f"results_treatment{k}_RS{seed}.csv"), index=False
        )

        # Save samples in a dataframe
        save_distributions_and_samples(
            seed, seed_path, gt_distributions_k, k, treatment_res
        )

    # Save runtime information of different treatment estimations in Pandas Dataframes
    runtime_df = pd.DataFrame()
    for j, runtime in enumerate(runtime_list):
        runtime_df = pd.DataFrame(runtime, index=[0])
        runtime_df.to_csv(
            os.path.join(seed_path, f"runtime_treatment{j}_RS{seed}.csv"), index=False
        )

    # Save also the history of memory usage in a csv file
    timeline = np.arange(0, len(mem_usage)) * delta_t
    mem_usage_df = pd.DataFrame({"Time": timeline, "Memory Usage": np.array(mem_usage)})
    mem_usage_df.to_csv(
        os.path.join(seed_path, f"memory_usage_RS{seed}.csv"), index=False
    )

    return results_df, runtime_df

def save_distributions_and_samples(seed, seed_path, gt_distributions, k, treatment_res):
    """
    Write the samples and distributions of one treatment to CSV files inside ``seed_path``.

    Parameters:
        seed: Seed identifier, embedded as ``RS<seed>`` in the names of the estimated outputs.
        seed_path: Directory that receives the files.
        gt_distributions: Mapping with the four ground-truth distributions; each value must
            provide ``to_csv``. All four keys are required.
        k: Treatment index, embedded in every filename.
        treatment_res: Mapping with the model outputs. Every optional key that is present
            is written to its own file; distributions are indexed by
            ``treatment_res["state_names"]`` before being written.
    """

    def _out(filename):
        # Full path of an output file inside the seed directory.
        return os.path.join(seed_path, filename)

    def _dump_distribution(key, filename):
        # Wrap the stored distribution in a frame labelled by state name and write it.
        table = pd.DataFrame(treatment_res[key])
        table.index = treatment_res["state_names"]
        table.to_csv(_out(filename))

    # Interventional samples: the conditional samples are expected alongside them,
    # and both are looked up before anything is written.
    if "Interventional Samples" in treatment_res:
        marginal_samples = treatment_res["Interventional Samples"]
        conditional_samples = treatment_res["Conditional Interventional Samples"]
        marginal_samples.to_csv(
            _out(f"interventional_samples_{k}_RS{seed}.csv"), index=False
        )
        conditional_samples.to_csv(
            _out(f"conditional_interventional_samples_{k}_RS{seed}.csv"), index=False
        )

    # Treated distributions: the conditional one is expected alongside the marginal one.
    if "Interventional Distribution" in treatment_res:
        _dump_distribution(
            "Interventional Distribution", f"treated_estimated_{k}_RS{seed}.csv"
        )
        _dump_distribution(
            "Conditional Interventional Distribution",
            f"treated_conditional_estimated_{k}_RS{seed}.csv",
        )

    # Control samples, each written only when present.
    if "Control Samples" in treatment_res:
        treatment_res["Control Samples"].to_csv(
            _out(f"control_samples_{k}_RS{seed}.csv"), index=False
        )
    if "Conditional Control Samples" in treatment_res:
        treatment_res["Conditional Control Samples"].to_csv(
            _out(f"conditional_control_samples_{k}_RS{seed}.csv"), index=False
        )

    # Control distributions, each written only when present.
    if "Control Interventional Distribution" in treatment_res:
        _dump_distribution(
            "Control Interventional Distribution", f"control_estimated_{k}_RS{seed}.csv"
        )
    if "Conditional Control Distribution" in treatment_res:
        _dump_distribution(
            "Conditional Control Distribution",
            f"control_conditional_estimated_{k}_RS{seed}.csv",
        )

    # Save ground truth interventional distributions (always written; names carry no seed)
    ground_truth_outputs = (
        ("Treated Interventional Distribution", f"treated_ground_truth_{k}.csv"),
        ("Control Interventional Distribution", f"control_ground_truth_{k}.csv"),
        (
            "Treated Conditional Interventional Distribution",
            f"treated_conditional_ground_truth_{k}.csv",
        ),
        (
            "Control Conditional Interventional Distribution",
            f"control_conditional_ground_truth_{k}.csv",
        ),
    )
    for gt_key, gt_filename in ground_truth_outputs:
        gt_distributions[gt_key].to_csv(_out(gt_filename))


def setup_seed(cfg):
    """
    Load and preprocess everything a seed run needs from the configuration.

    Parameters:
        cfg: Configuration object. The fields read here are ``dataset``, ``outcome``,
            ``evidence``, ``treatment``, ``experiment`` (``quantize_dataset`` and
            ``use_interventional_data``), ``dataset_size`` and ``balance_variable``;
            ``cfg`` itself is also handed to ``load_data`` and ``load_ground_truth``.

    Returns:
        tuple: ``(outcome, evidence, treatment, ground_truth_graph, data, int_table,
        ground_truth, gt_distributions)``, where ``data``, ``int_table``, ``ground_truth``,
        ``treatment`` and ``evidence`` are the values returned by ``preprocess_data``.

    Raises:
        ValueError: If ``cfg.dataset`` is not one of the datasets with a known number of
            quantisation steps.
    """
    # Validate the dataset before loading anything, so an unsupported name fails fast
    # with a clear message instead of an unbound ``quant_steps`` after loading.
    if cfg.dataset == "dataset0":
        quant_steps = 20
    elif cfg.dataset == "dataset1":
        quant_steps = 10
    else:
        raise ValueError(f"Invalid dataset: {cfg.dataset}")

    outcome: str = cfg.outcome
    evidence: dict[str, Any] = parse_dict(cfg.evidence)
    treatment: list[dict[str, Any]] = cfg.treatment
    experiment_cfg = cfg.experiment

    # Load data
    ground_truth_graph, data, int_table = load_data(cfg)
    ground_truth = load_ground_truth(cfg)
    data, int_table, ground_truth, treatment, evidence = preprocess_data(
        data,
        ground_truth,
        treatment,
        evidence,
        quantize=experiment_cfg.quantize_dataset,
        quant_steps=quant_steps,
        dataset_size=cfg.dataset_size,
        use_interventional_data=experiment_cfg.use_interventional_data,
        int_table=int_table,
        balance_variable=cfg.balance_variable,
    )
    gt_distributions = get_ground_truth_distributions(
        ground_truth, treatment, outcome, evidence
    )

    return (
        outcome,
        evidence,
        treatment,
        ground_truth_graph,
        data,
        int_table,
        ground_truth,
        gt_distributions,
    )


def run_model(
    causal_graph,
    experiment_cfg,
    treatment,
    outcome,
    evidence,
    seed,
    data,
    int_table,
    seed_path,
):
    """
    Dispatch the treatment estimation to the backend named by ``experiment_cfg.model_source``.

    ``"custom_model"`` calls ``run_custom_model`` (which additionally receives ``seed_path``
    as ``save_dir``); ``"dowhy"`` calls ``run_dowhy_experiment``.

    Returns:
        tuple: The ``(results, runtime)`` pair produced by the selected backend.

    Raises:
        ValueError: If ``experiment_cfg.model_source`` names neither backend.
    """
    source = experiment_cfg.model_source

    # Reject unknown backends up front.
    if source not in ("custom_model", "dowhy"):
        raise ValueError(f"Invalid model source: {experiment_cfg.model_source}")

    # Arguments that both backends take in the same form.
    positional = (seed, causal_graph, experiment_cfg, treatment, outcome)
    shared_kwargs = {"data": data, "int_table": int_table, "evidence": evidence}

    if source == "custom_model":
        results, runtime = run_custom_model(
            *positional, save_dir=seed_path, **shared_kwargs
        )
    else:
        results, runtime = run_dowhy_experiment(*positional, **shared_kwargs)

    return results, runtime


if __name__ == "__main__":
    # Command-line entry point: read the YAML config, run a single seed and store the
    # frames returned by run_seed directly inside the experiment directory.
    cli = argparse.ArgumentParser(description="Run causal experiment")
    cli.add_argument("--config_path", type=str, help="Path to the config YAML file")
    cli.add_argument(
        "--experiment_dir", type=str, help="Directory to save the experiment results"
    )
    cli.add_argument("--seed", type=int, help="Seed value for random number generation")
    options = cli.parse_args()

    experiment_dir = options.experiment_dir
    seed = options.seed

    # Load the config file and run the experiment
    with open(options.config_path, "r") as config_file:
        raw_config = yaml.safe_load(config_file)
    config = OmegaConf.create(raw_config)

    results, runtime = run_seed(config, experiment_dir, seed)

    results_csv = os.path.join(experiment_dir, f"results_RS{seed}.csv")
    runtime_csv = os.path.join(experiment_dir, f"runtime_RS{seed}.csv")
    results.to_csv(results_csv, index=False)
    runtime.to_csv(runtime_csv, index=False)
    print(f"Seed {seed} completed successfully.")
